{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "692646c3",
   "metadata": {},
   "source": [
    "## Out-of-Core Learning"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79d62d75",
   "metadata": {},
   "source": [
    "author: Jacob Schreiber <br>\n",
    "contact: jmschreiber91@gmail.com\n",
    "\n",
    "Out-of-core learning is the process of training a model on more data than can fit in memory. Several approaches can be described as out-of-core, but here we mean the ability to derive exact updates to a model from a massive data set, even though the entire data set cannot fit in memory.\n",
    "\n",
    "This out-of-core learning approach is implemented for all of pomegranate's models using two methods. The first is the `summarize` method, which takes in a batch of data and reduces it to additive sufficient statistics. Because these statistics are additive, each call after the first adds its results to the previously stored summaries. Once the entire data set has been seen, the stored sufficient statistics are identical to those that would have been derived had the entire data set been seen at once. The second is the `from_summaries` method, which uses the stored sufficient statistics to derive parameter updates for the model.\n",
    "\n",
    "A common solution to having too much data is to randomly select a subset that does fit in memory and use it in place of the full data set. While simple to implement, this approach is likely to yield lower-performing models because they are exposed to less data. By using out-of-core learning, however, one can train models on a massive amount of data without being limited by the amount of memory the computer has."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "732d90aa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n",
      "torch      : 1.13.0\n",
      "pomegranate: 1.0.0\n",
      "\n",
      "Compiler    : GCC 11.2.0\n",
      "OS          : Linux\n",
      "Release     : 4.15.0-208-generic\n",
      "Machine     : x86_64\n",
      "Processor   : x86_64\n",
      "CPU cores   : 8\n",
      "Architecture: 64bit\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%pylab inline\n",
    "import torch\n",
    "\n",
    "numpy.random.seed(0)\n",
    "numpy.set_printoptions(suppress=True)\n",
    "\n",
    "%load_ext watermark\n",
    "%watermark -m -n -p torch,pomegranate"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e77be408",
   "metadata": {},
   "source": [
    "### `summarize` and `from_summaries`\n",
    "\n",
    "Let's start off simple by fitting a normal distribution in an out-of-core manner. First, we'll generate some random data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "40c81d88",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = torch.randn(1000, 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68782e46",
   "metadata": {},
   "source": [
    "Then, we can initialize a distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fec969dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pomegranate.distributions import Normal\n",
    "\n",
    "dist = Normal()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e18b3d50",
   "metadata": {},
   "source": [
    "Now let's summarize a few batches of data using the `summarize` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8d181be6",
   "metadata": {},
   "outputs": [],
   "source": [
    "dist.summarize(X[:200])\n",
    "dist.summarize(X[200:])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6df91e38",
   "metadata": {},
   "source": [
    "Importantly, summarizing data doesn't update parameters by itself. Rather, it extracts additive sufficient statistics from the data. Each time `summarize` is called, these statistics are added to the previously aggregated statistics.\n",
    "\n",
    "In order to update the parameters of the model, you need to call the `from_summaries` method. This method updates the parameters of the model given the stored sufficient statistics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9cbbe4dc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Parameter containing:\n",
       " tensor([ 0.0175,  0.0096,  0.0228,  0.0592, -0.0089]),\n",
       " Parameter containing:\n",
       " tensor([[ 0.9786, -0.0106,  0.0344,  0.0571,  0.0330],\n",
       "         [-0.0106,  0.9970,  0.0165, -0.0330,  0.0021],\n",
       "         [ 0.0344,  0.0165,  0.9405, -0.0075, -0.0374],\n",
       "         [ 0.0571, -0.0330, -0.0075,  1.0399,  0.0333],\n",
       "         [ 0.0330,  0.0021, -0.0374,  0.0333,  0.9978]]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dist.from_summaries()\n",
    "dist.means, dist.covs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ae90d8d",
   "metadata": {},
   "source": [
    "This update is exactly the same as the one we would get by training on the entire data set at once."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c33e1a42",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Parameter containing:\n",
       " tensor([ 0.0175,  0.0096,  0.0228,  0.0592, -0.0089]),\n",
       " Parameter containing:\n",
       " tensor([[ 0.9786, -0.0106,  0.0344,  0.0571,  0.0330],\n",
       "         [-0.0106,  0.9970,  0.0165, -0.0330,  0.0021],\n",
       "         [ 0.0344,  0.0165,  0.9405, -0.0075, -0.0374],\n",
       "         [ 0.0571, -0.0330, -0.0075,  1.0399,  0.0333],\n",
       "         [ 0.0330,  0.0021, -0.0374,  0.0333,  0.9978]]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dist = Normal()\n",
    "dist.summarize(X)\n",
    "dist.from_summaries()\n",
    "dist.means, dist.covs"
   ]
  },
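  {
   "cell_type": "markdown",
   "id": "a1f3c5e7",
   "metadata": {},
   "source": [
    "To see why the batched and full-data updates agree, the additivity can be sketched by hand. The cell below is a minimal illustration, not pomegranate's internal implementation. It assumes that a full-covariance normal distribution can be summarized by three quantities (the number of examples, the per-feature sum, and the sum of outer products) and accumulates them over batches. The names `X_demo`, `n`, `x_sum`, and `xxt_sum` are made up for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2d4f6a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Illustrative data; float64 keeps rounding error small in the comparison\n",
    "X_demo = torch.randn(1000, 5, dtype=torch.float64)\n",
    "\n",
    "# Additive sufficient statistics for a full-covariance normal (assumed form)\n",
    "n = 0\n",
    "x_sum = torch.zeros(5, dtype=torch.float64)\n",
    "xxt_sum = torch.zeros(5, 5, dtype=torch.float64)\n",
    "\n",
    "for start in range(0, X_demo.shape[0], 200):\n",
    "    batch = X_demo[start:start + 200]\n",
    "    n += batch.shape[0]            # number of examples seen so far\n",
    "    x_sum += batch.sum(dim=0)      # running sum of each feature\n",
    "    xxt_sum += batch.T @ batch     # running sum of outer products\n",
    "\n",
    "# Maximum-likelihood estimates derived only from the accumulated statistics\n",
    "mean = x_sum / n\n",
    "cov = xxt_sum / n - torch.outer(mean, mean)\n",
    "\n",
    "# The same estimates computed from the full data set at once\n",
    "full_mean = X_demo.mean(dim=0)\n",
    "centered = X_demo - full_mean\n",
    "full_cov = centered.T @ centered / X_demo.shape[0]\n",
    "\n",
    "# Both checks should pass up to floating-point rounding\n",
    "print(torch.allclose(mean, full_mean), torch.allclose(cov, full_cov))"
   ]
  },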
  {
   "cell_type": "markdown",
   "id": "9b217107",
   "metadata": {},
   "source": [
    "### Batched Training\n",
    "\n",
    "Sometimes your data is so large that it cannot fit in memory (either CPU or GPU). In these cases, we can use the out-of-core API to train on one batch at a time. This is similar to how neural networks are trained except that, rather than updating after each batch (or aggregating gradients over a small number of batches), we can summarize over a much larger number of batches, potentially even the entire data set, to get an exact update. Let's see an example of how that might work."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a6232d3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "dist = Normal()\n",
    "\n",
    "for i in range(10):\n",
    "    X_batch = torch.randn(1000, 20)  # This is meant to mimic loading a batch of data\n",
    "    dist.summarize(X_batch)\n",
    "    del X_batch  # Now we can discard the batch\n",
    "\n",
    "dist.from_summaries()"
   ]
  },
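  {
   "cell_type": "markdown",
   "id": "c3e5a7b9",
   "metadata": {},
   "source": [
    "In practice the batches would come from disk rather than from `torch.randn`. As one possible sketch, the cell below writes a synthetic data set to a temporary file with `numpy.memmap` and then streams it back in slices, so that only one batch is held in memory at a time during training. The file path, the array shape, the batch size, and the name `disk_dist` are assumptions chosen for illustration; any loader that yields one batch at a time would work the same way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4f6b8c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "import numpy\n",
    "import torch\n",
    "from pomegranate.distributions import Normal\n",
    "\n",
    "# Hypothetical on-disk data set: write synthetic data to a temporary file.\n",
    "# This setup step builds the data in memory only to keep the example short.\n",
    "path = os.path.join(tempfile.mkdtemp(), 'X_disk.dat')\n",
    "X_disk = numpy.memmap(path, dtype='float32', mode='w+', shape=(10000, 20))\n",
    "X_disk[:] = numpy.random.randn(10000, 20)\n",
    "X_disk.flush()\n",
    "del X_disk\n",
    "\n",
    "# Re-open the file read-only; slices are read from disk only when indexed\n",
    "X_disk = numpy.memmap(path, dtype='float32', mode='r', shape=(10000, 20))\n",
    "\n",
    "disk_dist = Normal()\n",
    "for start in range(0, X_disk.shape[0], 1000):\n",
    "    # Copy one slice into memory and convert it to a tensor\n",
    "    X_batch = torch.tensor(numpy.array(X_disk[start:start + 1000]))\n",
    "    disk_dist.summarize(X_batch)\n",
    "\n",
    "# Update the parameters once, after every batch has been summarized\n",
    "disk_dist.from_summaries()"
   ]
  },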
  {
   "cell_type": "markdown",
   "id": "7c30c9f4",
   "metadata": {},
   "source": [
    "Batched training is easy to implement for simple probability distributions, but it can also be done with more complicated models if you are willing to code your own expectation-maximization loop. For instance, let's train a mixture model using a modified version of the training code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "14012265",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1] Improvement: 1945.53125, Time: 0.01443s\n",
      "[2] Improvement: 99.875, Time: 0.01562s\n",
      "[3] Improvement: 34.1875, Time: 0.01019s\n",
      "[4] Improvement: 17.65625, Time: 0.00994s\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "from pomegranate.gmm import GeneralMixtureModel\n",
    "\n",
    "X = torch.randn(10000, 20)\n",
    "\n",
    "model = GeneralMixtureModel([Normal(), Normal()])\n",
    "\n",
    "logp = None\n",
    "for i in range(5):  # Each pass over the data is one EM iteration\n",
    "    start_time = time.time()\n",
    "\n",
    "    last_logp = logp\n",
    "\n",
    "    # E-step: accumulate sufficient statistics and log probability batch by batch\n",
    "    logp = 0\n",
    "    for j in range(0, X.shape[0], 1000):  # Train on batches of size 1000\n",
    "        logp += model.summarize(X[j:j+1000])\n",
    "\n",
    "    if i > 0:\n",
    "        improvement = logp - last_logp\n",
    "        duration = time.time() - start_time\n",
    "        print(\"[{}] Improvement: {}, Time: {:4.4}s\".format(i, improvement, duration))\n",
    "\n",
    "    # M-step: update the parameters from the accumulated statistics\n",
    "    model.from_summaries()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
